/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuda_runtime_api.h>

#include "moe_gemm_kernels.h"

namespace tensorrt_llm::kernels::cutlass_kernels {

// Declaration of the generic TMA warp-specialized MoE grouped-GEMM launcher.
// Only the declaration lives in this header; the definition is elsewhere.
//
// Keep in sync with the signature generated by generate_kernels.py.
// NOTE(review): presumably that script emits explicit instantiations of this template.
// Any change to the template parameter kinds/order or the function parameter
// types/order here must be mirrored there. Otherwise the declared specializations
// will likely have no matching definition and fail at link time — confirm against
// the generator.
//
// Template parameters (semantics are defined by the implementation, not visible here):
//   Arch          - architecture tag type (presumably a CUTLASS arch tag, e.g. an SM90
//                   tag — TODO confirm).
//   T             - element type of the GEMM's input/activation operand (assumed from name).
//   WeightType    - element type of the weight operand (assumed from name).
//   OutputType    - element type of the GEMM output (assumed from name).
//   EpilogueTag   - tag type selecting the epilogue; meaning defined elsewhere.
//   FUSION        - non-type parameter of type
//                   TmaWarpSpecializedGroupedGemmInput::EpilogueFusion (declared in
//                   moe_gemm_kernels.h) selecting an epilogue-fusion variant.
//   TileShape     - tile shape type (presumably a CuTe/CUTLASS CTA tile shape — confirm).
//   ClusterShape  - cluster shape type (presumably the thread-block cluster shape — confirm).
//   IsMXFPX       - compile-time flag; presumably selects an MX (microscaling) FP format
//                   path — confirm against the definition.
//   BIAS          - compile-time flag; presumably enables a bias term in the epilogue —
//                   confirm against the definition.
//
// Function parameters:
//   hopper_input          - grouped-GEMM problem description, passed by value.
//   num_experts           - number of experts (presumably the number of GEMM groups).
//   multi_processor_count - SM count of the device; presumably used to size the launch
//                           — TODO confirm.
//   stream                - CUDA stream the work is associated with.
//   kernel_occupancy      - int out-pointer. NOTE(review): presumably, when non-null, the
//                           launcher only reports occupancy instead of launching — verify
//                           in the definition.
//   workspace_size        - size_t out-pointer. NOTE(review): presumably, when non-null,
//                           the launcher reports the required workspace size — verify in
//                           the definition.
template <typename Arch, typename T, typename WeightType, typename OutputType, typename EpilogueTag,
          TmaWarpSpecializedGroupedGemmInput::EpilogueFusion FUSION, typename TileShape,
          typename ClusterShape, bool IsMXFPX, bool BIAS>
void tma_warp_specialized_generic_moe_gemm_kernelLauncher(
    TmaWarpSpecializedGroupedGemmInput hopper_input, int num_experts, int multi_processor_count,
    cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size);

}  // namespace tensorrt_llm::kernels::cutlass_kernels
